#!/usr/bin/env python
# -*- coding: utf-8 -*-

import numpy as np
import time
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.models import Sequential, Model, load_model
from tensorflow.keras.layers import Dense, Dropout, Flatten, Activation, Input, AveragePooling2D
from tensorflow.keras.layers import Conv2D, MaxPooling2D, BatchNormalization, ReLU, Reshape, Conv2DTranspose
from tensorflow.keras import regularizers
from tensorflow.keras.optimizers import Adam, SGD
from tensorflow.keras.regularizers import l2
from tensorflow.keras.utils import get_custom_objects
from tensorflow.keras.callbacks import ModelCheckpoint, LearningRateScheduler
from tensorflow.keras.callbacks import ReduceLROnPlateau
from tensorflow.keras.layers import LeakyReLU
import DonaldDuckDataset
from DonaldDuckModel import DonaldDuckModel

def lr_schedule(epoch):
    """Return the step-decayed learning rate for ``epoch`` and print it.

    The base rate is 1e-3. Once ``epoch`` is strictly greater than 80, 120,
    160 or 180, the base rate is multiplied by 1e-1, 1e-2, 1e-3 or 0.5e-3
    respectively. The largest threshold that has been passed wins.

    # Arguments
        epoch (int): epoch number being scheduled.

    # Returns
        lr (float): learning rate for that epoch.
    """
    base_lr = 1e-3
    # (exclusive epoch threshold, multiplier), checked from the latest stage down.
    decay_steps = ((180, 0.5e-3), (160, 1e-3), (120, 1e-2), (80, 1e-1))
    factor = next((mult for bound, mult in decay_steps if epoch > bound), None)
    lr = base_lr if factor is None else base_lr * factor
    print('Learning rate: ', lr)
    return lr

def resnet_layer(inputs,
                 num_filters=16,
                 kernel_size=3,
                 strides=1,
                 activation='relu',
                 batch_normalization=True,
                 conv_first=True):
    """Stack a 2D convolution with optional batch normalization and activation.

    The convolution always uses ``'same'`` padding, ``he_normal``
    initialisation and an L2(1e-4) kernel regularizer.

    # Arguments
        inputs (tensor): input tensor from input image or previous layer
        num_filters (int): Conv2D number of filters
        kernel_size (int): Conv2D square kernel dimensions
        strides (int): Conv2D square stride dimensions
        activation (string or None): activation name; None skips activation
        batch_normalization (bool): whether to include batch normalization
        conv_first (bool): order conv -> bn -> activation when True,
            bn -> activation -> conv when False

    # Returns
        x (tensor): tensor as input to the next layer
    """
    conv = Conv2D(num_filters,
                  kernel_size=kernel_size,
                  strides=strides,
                  padding='same',
                  kernel_initializer='he_normal',
                  kernel_regularizer=l2(1e-4))

    def _normalize_and_activate(tensor):
        # Optional BN followed by optional activation, in that order.
        if batch_normalization:
            tensor = BatchNormalization()(tensor)
        if activation is not None:
            tensor = Activation(activation)(tensor)
        return tensor

    if conv_first:
        return _normalize_and_activate(conv(inputs))
    return conv(_normalize_and_activate(inputs))

class DonaldDuckCNN(DonaldDuckModel):
    """Plain stacked-convolution classifier built on ``DonaldDuckModel``.

    The network is a sequence of Conv2D -> BatchNormalization -> activation
    blocks (with periodic max-pooling and filter doubling), followed by two
    256-unit dense layers and a softmax classification head.

    NOTE(review): relies on ``self.dataset``, ``self.input_shape``,
    ``self.num_classes``, ``self.act_func`` and ``self.learning_rate`` being
    provided by ``DonaldDuckModel`` (not visible here). Confirm against the base class.
    """

    def setModel(
            self,
            conv_layers_num=3,
            filters=32,
            kernel_size=(3, 3),
            name='CNN'
    ):
        """Name the run, build the architecture and print the Keras summary.

        Args:
            conv_layers_num: number of convolutional blocks to stack.
            filters: filter count of the first convolutional block.
            kernel_size: kernel size used by every convolution.
            name: prefix of ``self.name``; also used as the Keras model name.
        """
        # Encode the hyper-parameters in the run name, e.g. 'CNN_<dataset>_3_32'.
        self.name=name+'_'+self.dataset.name+'_'+str(conv_layers_num)+'_'+str(filters)
        self.setArchitecture(
            conv_layers_num=conv_layers_num,
            filters=filters,
            kernel_size=kernel_size,
            name=name
        )
        self.model.summary()

    def convLayer(self, filters, kernel_size, dropout=False, pooling=False,
                  dropout_rate=0.5):
        """Return a Sequential Conv2D -> BN -> activation block.

        Args:
            filters: number of convolution filters.
            kernel_size: convolution kernel size ('same' padding).
            dropout: append a Dropout layer after the activation.
            pooling: append a 2x2 MaxPooling2D layer at the end.
            dropout_rate: fraction of units dropped when ``dropout`` is True.
        """
        layer = Sequential()
        layer.add(
            Conv2D(
                filters=filters,
                kernel_size=kernel_size,
                padding='same'
            )
        )
        layer.add(BatchNormalization())
        layer.add(Activation(self.act_func))
        if dropout:
            # Keras' Dropout has no default rate; the previous bare
            # ``Dropout()`` raised TypeError whenever dropout was requested.
            layer.add(Dropout(dropout_rate))
        if pooling:
            layer.add(MaxPooling2D(pool_size=(2, 2)))
        return layer

    def setArchitecture(
            self,
            conv_layers_num=3,
            filters=32,
            kernel_size=(3, 3),
            name='CNN'
    ):
        """Build the Keras model and its training configuration.

        Block ``i`` (0-based) is followed by max-pooling when ``i % 3 == 2``
        or when it is the final block. After every odd-indexed block the
        filter count doubles. It is clamped to at most 512 after every block.

        Sets ``self.model``, ``self.opt`` (Adam using ``self.learning_rate``),
        ``self.loss`` and ``self.metrics``.
        """
        max_filters = 512
        inputs = Input(shape=self.input_shape)
        conv_layers = []
        for cln in range(conv_layers_num):
            conv_layers.append(
                self.convLayer(
                    filters=filters,
                    kernel_size=kernel_size,
                    # Pool after every third block and always after the last.
                    pooling=(cln == conv_layers_num - 1 or cln % 3 == 2)
                )
            )
            if cln % 2 == 1:
                filters = filters * 2
            filters = min(filters, max_filters)
        x = inputs
        for conv_layer in conv_layers:
            x = conv_layer(x)
        x = Flatten()(x)
        x = Dense(256, activation=self.act_func)(x)
        x = Dense(256, activation=self.act_func)(x)
        x = Dense(self.num_classes)(x)
        outputs = Activation('softmax')(x)
        model = Model(inputs=inputs, outputs=outputs, name=name)

        self.opt = Adam(learning_rate=self.learning_rate)
        self.loss = 'categorical_crossentropy'
        self.metrics = ['accuracy']

        self.model = model

class DonaldDuckVGG16(DonaldDuckModel):
    """VGG16-style classifier built on ``DonaldDuckModel``.

    The network has 13 Conv-BN-ReLU layers in five max-pooled stages
    (64, 128, 256, 512, 512 filters), followed by one 512-unit dense layer
    and a softmax head. Within each stage, the first convolution is followed
    by Dropout(0.5) and the last by 2x2 max-pooling.

    NOTE(review): relies on ``self.dataset``, ``self.input_shape``,
    ``self.num_classes`` and ``self.learning_rate`` coming from
    ``DonaldDuckModel``. Confirm against the base class.
    """

    def setModel(self):
        """Name the run after the dataset and build the architecture.

        Unlike ``DonaldDuckCNN.setModel``, no model summary is printed.
        """
        self.name = 'VGG16_' + self.dataset.name
        self.setArchitecture()

    def convLayer(self, filters, kernel_size=(3,3), pooling=False, dropout=False):
        """Return a Sequential Conv2D -> BN -> ReLU block.

        Optional Dropout(0.5) and 2x2/stride-2 max-pooling are appended, in
        that order.
        """
        block = Sequential()
        block.add(Conv2D(filters=filters, kernel_size=kernel_size, padding='same'))
        block.add(BatchNormalization())
        block.add(Activation('relu'))
        if dropout:
            block.add(Dropout(0.5))
        if pooling:
            block.add(MaxPooling2D(pool_size=(2,2), strides=(2, 2)))
        return block

    def setArchitecture(self):
        """Build the VGG16-style model and its training configuration.

        Sets ``self.model``, ``self.opt`` (Adam using ``self.learning_rate``),
        ``self.loss`` and ``self.metrics``.
        """
        # (filters per convolution, convolutions in the stage)
        stages = ((64, 2), (128, 2), (256, 3), (512, 3), (512, 3))

        inputs = Input(shape=self.input_shape)
        tensor = inputs
        for width, depth in stages:
            for position in range(depth):
                tensor = self.convLayer(
                    filters=width,
                    pooling=(position == depth - 1),
                    dropout=(position == 0),
                )(tensor)

        tensor = Flatten()(tensor)
        tensor = Dense(512, activation='relu')(tensor)
        tensor = Dense(self.num_classes)(tensor)
        outputs = Activation('softmax')(tensor)
        model = Model(inputs=inputs, outputs=outputs, name='VGG')

        self.opt = Adam(learning_rate=self.learning_rate)
        self.loss = 'categorical_crossentropy'
        self.metrics = ['accuracy']

        self.model = model

if __name__ == '__main__':
    # Smoke run: build a 3-block, 32-filter CNN for the MNIST dataset.
    # setModel also prints the Keras model summary.
    mnist_cnn = DonaldDuckCNN(DonaldDuckDataset.MNIST())
    mnist_cnn.setModel(conv_layers_num=3, filters=32, kernel_size=(3, 3))